package priv.lhy.remote.rule;

import com.alibaba.cloud.nacos.ribbon.NacosServer;
import com.netflix.client.config.IClientConfig;
import com.netflix.loadbalancer.AbstractLoadBalancerRule;
import com.netflix.loadbalancer.ILoadBalancer;
import com.netflix.loadbalancer.Server;
import org.springframework.stereotype.Component;
import priv.lhy.remote.aop.PassParameters;

import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/**
 * Ribbon load-balancing rule that routes a request to servers whose Nacos metadata
 * carries the same {@code tag} value as the current request, choosing among the
 * matching servers in round-robin order.
 *
 * <p>The request tag is read from the map returned by {@link PassParameters#get()}.
 * When the request carries no tag, every reachable server is a candidate.
 *
 * @author :lihy
 * @date :2021/4/26 10:12
 **/
@Component
public class TagRule extends AbstractLoadBalancerRule {
    /** Key used both for the request parameter and for the Nacos instance metadata entry. */
    private static final String TAG_KEY = "tag";

    /** Round-robin cursor; atomic so that concurrent callers each get a consistent update. */
    private final AtomicInteger atomicInteger = new AtomicInteger(0);

    /**
     * No client configuration is consumed by this rule.
     *
     * @param iClientConfig ignored
     */
    @Override
    public void initWithNiwsConfig(IClientConfig iClientConfig) {
        // Intentionally empty: this rule has no configurable properties.
    }

    /**
     * Chooses a server from the load balancer attached to this rule.
     *
     * @param key load-balancer key supplied by the caller; not used for selection
     * @return the chosen server, or {@code null} when no candidate is available
     */
    @Override
    public Server choose(Object key) {
        return choose(getLoadBalancer(), key);
    }

    /**
     * Chooses a server from {@code lb}, restricted to servers matching the request tag
     * when one is present.
     *
     * @param lb  load balancer to pick from; {@code null} yields {@code null}
     * @param key load-balancer key supplied by the caller; not used for selection
     * @return the next candidate in round-robin order, or {@code null} when {@code lb}
     *         is {@code null} or no reachable server qualifies
     */
    public Server choose(ILoadBalancer lb, Object key) {
        if (lb == null) {
            return null;
        }

        // A missing parameter map is treated the same as "no tag requested".
        Map<String, String> headers = PassParameters.get();
        String tag = headers == null ? null : headers.get(TAG_KEY);

        List<Server> candidates = lb.getReachableServers();
        if (candidates == null) {
            return null;
        }
        if (tag != null) {
            candidates = filterByTag(candidates, tag);
        }
        if (candidates.isEmpty()) {
            return null;
        }

        // The index must be taken modulo the size of the list that is actually indexed;
        // using the size of the full server list can run past the end of the reachable list.
        int next = getServerIndex(candidates.size());
        return candidates.get(next);
    }

    /**
     * Advances the shared round-robin cursor and returns the new position.
     *
     * @param allServersSize number of candidates to cycle over; must be positive
     * @return an index in {@code [0, allServersSize)}
     * @throws IllegalArgumentException if {@code allServersSize} is not positive
     */
    public int getServerIndex(int allServersSize) {
        if (allServersSize <= 0) {
            throw new IllegalArgumentException(
                    "allServersSize must be positive, got " + allServersSize);
        }
        // updateAndGet retries the compare-and-set internally until the update succeeds.
        return this.atomicInteger.updateAndGet(current -> (current + 1) % allServersSize);
    }

    /** Returns the servers from {@code servers} whose metadata tag equals {@code tag}. */
    private static List<Server> filterByTag(List<Server> servers, String tag) {
        return servers.stream()
                .filter(server -> hasTag(server, tag))
                .collect(Collectors.toList());
    }

    /** Tells whether {@code server} is a Nacos server whose metadata tag equals {@code tag}. */
    private static boolean hasTag(Server server, String tag) {
        if (!(server instanceof NacosServer)) {
            return false;
        }
        Map<String, String> metadata = ((NacosServer) server).getMetadata();
        // tag is non-null here, so comparing from its side tolerates untagged servers.
        return metadata != null && tag.equals(metadata.get(TAG_KEY));
    }
}
